Skip to content

[CK Tile] Unification work - mma transformations pipeline#5508

Open
chris-tsiaousis-hpc wants to merge 16 commits intodevelopfrom
users/chris-tsiaousis-hpc/ck/mma-transformations-pipeline
Open

[CK Tile] Unification work - mma transformations pipeline#5508
chris-tsiaousis-hpc wants to merge 16 commits intodevelopfrom
users/chris-tsiaousis-hpc/ck/mma-transformations-pipeline

Conversation

@chris-tsiaousis-hpc
Copy link
Contributor

Motivation

In this PR we showcase how the amdgcn structs could be used in a pipeline that does some extra pre/post processing.
For the sparse intrinsics, so far we compressed the A vector "on the fly" right before the execution of the builtin. This might introduce performance issues down the line if, for example, the user decided to chain multiple sparse builtins. We tackle this problem by creating a specific SparseCompressTransform.

A MmaPipelineBase is also created to facilitate those kind of higher level compositions of the amdgcn structs and is integrated to the existing WaveWiseMma prototype. There is an effort to facilitate future operations, like swizzle A/B, C transpose or double/quad attr num access through the MmaPipelineOptionFlags, but those are not yet defined and should do so in a future PR.
The pipeline base class is basically at the RFC stage.

We also create a runtime test for the existing WaveWiseMma, as well as one for the SparseMma pipeline.

Technical Details

The goal should be to have the pipeline easily expandable. May the CRTP of the base class or the interface in general be insufficient or unable to handle all of our needs, then a design modification should be discussed.

Test Plan

New tests are added.

Test Result

Tests should pass.

Submission Checklist

@wj-laskowski wj-laskowski self-requested a review March 17, 2026 09:15
@chris-tsiaousis-hpc chris-tsiaousis-hpc force-pushed the users/chris-tsiaousis-hpc/ck/mma-transformations-pipeline branch from 12f78be to a96f4c8 Compare March 17, 2026 09:54
idx = detail::compress_a_impl<fp16_t, CompressedSize>(v);

VecCompressed result;
__builtin_memcpy(&result, &v, sizeof(VecCompressed));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing a

        idx = detail::compress_a_impl<ScalarT, CompressedSize>(v);
        return *static_cast<const VecCompressed*>(&v);

here gives us an error:

sparse_transforms.hpp:79:17: error: static_cast from '_Float16 * __attribute__((ext_vector_type(16)))' to 'const VecCompressed *' (aka 'const ck_tile::impl::ext_vector<_Float16, 8>::type *') is not allowed
   79 |         return *static_cast<const VecCompressed*>(&v);

Can we avoid copying ? Ideas?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the buffer earlier used reinterpret_cast, maybe it works here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option is to use what the original code did:

const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};

With some if constexpr or some sort of template for different sizes. There may be some perf reason why it was implemented like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reinterpret_cast gives a similar error, unfortunately.
What the "original" code did is still creating a new vector and copying over elements to it. It is also not universal since the compressed A vector doesn't have the same size for GFX9 and GFX12...
I will investigate a way to change the size of the ext_vector without unnecessary copies.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see... it looks like it's gonna be hard to deal with with w/o creating a copy. If there is nothing we can do, the original code will be the safest option to go with

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment here: #5508 (comment)

@chris-tsiaousis-hpc chris-tsiaousis-hpc force-pushed the users/chris-tsiaousis-hpc/ck/mma-transformations-pipeline branch from a96f4c8 to 06106c2 Compare March 17, 2026 10:29
@chris-tsiaousis-hpc chris-tsiaousis-hpc added the organization: streamhpc contributors from streamhpc label Mar 17, 2026
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
struct MmaTransformsDefaultSelector;
struct MmaTransformsDefaultSelector
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not good. It was done because DEVICE and HOST code are weirdly intertwined and I'll think of a way to revert this.

C_TRANSPOSE = 0x1,
SWIZZLE_A = 0x2,
SWIZZLE_B = 0x4,
DOUBLE_ATTR_NUM_ACCESS = 0x8,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if all these operations belong here and whether we should treat them the same as compress_a. Did we discuss this already? I thought *ATTR_NUM_ACCESS can always be computed deterministically, i.e. not a pipeline option

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think Swizzle and AttrNumAccess are just modifiers on the tile distribution encoding and do not affect the vector fragments. Same for CTranspose except the vector fragments get swapped. I don't know if there is a need for explicit transposition anywhere.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well the CTranspose can be kept as a pipeline implementation in the base class. Thoughts?

@wj-laskowski wj-laskowski self-requested a review March 18, 2026 09:55
Copy link
Contributor

@wj-laskowski wj-laskowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work! I like the new approach, got a few questions

idx = detail::compress_a_impl<fp16_t, CompressedSize>(v);

VecCompressed result;
__builtin_memcpy(&result, &v, sizeof(VecCompressed));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the buffer earlier used reinterpret_cast, maybe it works here?

* non‑zeros are found, remaining fields default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this to exist separately from the existing compress_a function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well the idea is that the compress_a function will be removed when we integrate this to CK Tile. I also reverted the changes done to that function to minimize untested regression.

template <typename VecType>
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v, int32_t& idx)
{
using VecTraits = vector_traits<std::decay_t<VecType>>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CK library generally uses remove_cvref_t<> instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't use that as it is C++20, didn't know ck has its own implementation. I'll use that, thanks!

COMPRESS_A = 0x20,
};

struct MmaPipelineOptionFlags
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very verbose but I guess necessary if we really don't want to allow raw enums...

Comment on lines +127 to +134
template <typename VecTA, typename VecTB, typename VecTC>
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
{
auto pre = Derived::template preApply<Flags>(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
Derived::execImpl(pre);
return Derived::template postApply<Flags>(pre);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha we are making a second-order wrapper for the intrinsic just like in CK Tile, making more sense to me now.

@krithalith
Copy link
Contributor

Nice work, design makes sense to me an should be compatible with CK Tile, will need to change the amdgcn_mma_base and the tile distribution enc calc slightly if we go with this.

I don't understand this comment though:

This might introduce performance issues down the line if, for example, the user decided to chain multiple sparse builtins.

How would you chain sparse ops? The result is not sparse.

@chris-tsiaousis-hpc chris-tsiaousis-hpc force-pushed the users/chris-tsiaousis-hpc/ck/mma-transformations-pipeline branch from 9f7a0ab to e2e5634 Compare March 18, 2026 16:02
idx = detail::compress_a_impl<ScalarT, CompressedSize>(v);

// TODO c++20: Use bit_cast
return *std::launder(reinterpret_cast<VecCompressed*>(&v));
Copy link
Contributor Author

@chris-tsiaousis-hpc chris-tsiaousis-hpc Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also this thread of comments: #5508 (comment)

I used reinterpret_cast with std::launder. TBF this might have implications that I (and we) cannot foresee. You can't simply remove UB with an std::launder. Or maybe I'm wrong and that's the correct way to go for C++17. I'll research the internals of std::bit_cast to make an educated guess.
This code should definitely be removed when we stop caring about C++17.

For the time being this segment is tested and works as expected.

@chris-tsiaousis-hpc chris-tsiaousis-hpc force-pushed the users/chris-tsiaousis-hpc/ck/mma-transformations-pipeline branch from 2f30a5b to 2067a2d Compare March 19, 2026 11:26
@chris-tsiaousis-hpc chris-tsiaousis-hpc changed the base branch from develop to users/krithalith/ck/unification_policy_struct_refactor March 19, 2026 11:29
@chris-tsiaousis-hpc chris-tsiaousis-hpc marked this pull request as ready for review March 19, 2026 12:42
@chris-tsiaousis-hpc chris-tsiaousis-hpc requested a review from a team as a code owner March 19, 2026 12:42
Copy link
Contributor

@wj-laskowski wj-laskowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for addressing my comments!

@krithalith krithalith force-pushed the users/krithalith/ck/unification_policy_struct_refactor branch from 9859fad to 4314d69 Compare March 20, 2026 08:45
@chris-tsiaousis-hpc chris-tsiaousis-hpc force-pushed the users/chris-tsiaousis-hpc/ck/mma-transformations-pipeline branch from c9a3f75 to cd0a702 Compare March 20, 2026 08:56
static_assert(Flags == MmaPipelineOptionFlags(MmaPipelineOptionFlag::COMPRESS_A));
static_assert(MmaPipelineOptionFlags(Flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A));
static_assert(TransposeC ==
MmaPipelineOptionFlags(Flags).testFlag(MmaPipelineOptionFlag::C_TRANSPOSE));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, isn't this always true?

// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
auto pre = Derived::template preApply<Flags_>(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
hasFlag<MmaPipelineOptionFlag::C_TRANSPOSE>() ? std::forward<VecTB>(b)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this works only if A and B have the same type and size

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, CTranspose will not be available for all intrinsics. Also I don't think CTranspose is possible for sparse intrinsics.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll disable them for sparse then!

namespace sparse::detail {
// TODO: c++20: return MmaPipelineOptionFlags directly
template <bool TransposeC>
constexpr inline int getFlags()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be very clear about the distinction between MmaOp flags and Pipeline flags. Also maybe we can come up with a better name than "pipeline", because pipeline already exists as a term for the wavegroup-sized full pipeline 1-2 levels higher.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without giving much thought on it I am thinking MmaComposition but composition makes me think of using different mma structs. So it is not really a composition.
If you guys have any ideas please reach out!

Base automatically changed from users/krithalith/ck/unification_policy_struct_refactor to develop March 20, 2026 15:07
The default implementation should include the pass through transforms and is needed to avoid instantiations of an undefined template.

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
This is because this filename was too general for what it did. Also transfered basic components to the reusable base class.

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
…d A vector size)

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
We should, in the future remove the `::Type`.

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
…unction

Also added a test for this.

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Allow pipeline instantiations on the host side without having to constexpr conditionaly use them within the kernel. Also utilize the warning added to the host/unsupported amdgcn struct.

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
…support

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
@chris-tsiaousis-hpc chris-tsiaousis-hpc force-pushed the users/chris-tsiaousis-hpc/ck/mma-transformations-pipeline branch from cd0a702 to 4c8c262 Compare March 20, 2026 16:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants